{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "\n\n# Engine Caching\n\nAs model sizes increase, so does the cost of compilation. With AOT methods\nlike ``torch_tensorrt.dynamo.compile``, this cost is paid upfront. However, if the weights\nchange, the session ends, or you are using JIT methods like ``torch.compile``, graphs are\nre-compiled each time they are invalidated, so this cost is paid repeatedly.\nEngine caching mitigates this cost by saving constructed engines to disk\nand re-using them when possible. This tutorial demonstrates how to use engine caching\nwith TensorRT in PyTorch. Engine caching can significantly speed up subsequent model\ncompilations by reusing previously built TensorRT engines.\n\nWe'll explore two approaches:\n\n    1. Using torch_tensorrt.dynamo.compile\n    2. Using torch.compile with the TensorRT backend\n\nThe example uses a pre-trained ResNet18 model and shows the\ndifferences between compilation without caching, with caching enabled,\nand when reusing cached engines.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import os\nfrom typing import Dict, Optional\n\nimport numpy as np\nimport torch\nimport torch_tensorrt as torch_trt\nimport torchvision.models as models\nfrom torch_tensorrt.dynamo._defaults import TIMING_CACHE_PATH\nfrom torch_tensorrt.dynamo._engine_cache import BaseEngineCache\n\nnp.random.seed(0)\ntorch.manual_seed(0)\n\nmodel = models.resnet18(pretrained=True).to(\"cuda\").eval()\nenabled_precisions = {torch.float}\nmin_block_size = 1\nuse_python_runtime = False\n\n\ndef remove_timing_cache(path=TIMING_CACHE_PATH):\n    if os.path.exists(path):\n        os.remove(path)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Engine Caching for JIT Compilation\n\nThe primary goal of engine caching is to help speed up JIT workflows. ``torch.compile``\nprovides a great deal of flexibility in model construction, which makes it a good\nfirst tool to try when looking to speed up your workflow. However, historically\nthe cost of compilation, and in particular recompilation, has been a barrier to entry\nfor many users. Prior to the addition of engine caching, if a subgraph was invalidated for any reason,\nthat graph was reconstructed from scratch. Now, with ``cache_built_engines=True``,\nengines are saved to disk as they are constructed, keyed by a hash of their corresponding PyTorch subgraph. If\na subsequent compilation, either in this session or a new one, encounters a subgraph with a matching hash\n(and ``reuse_cached_engines=True``), the cache will\npull the built engine and **refit** the weights, which can reduce compilation times by orders of magnitude.\nAs such, in order to insert a new engine into the cache (i.e. ``cache_built_engines=True``),\nthe engine must be refittable (``immutable_weights=False``). See `refit_engine_example` for more details.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "def torch_compile(iterations=3):\n    times = []\n    start = torch.cuda.Event(enable_timing=True)\n    end = torch.cuda.Event(enable_timing=True)\n\n    # The 1st iteration is to measure the compilation time without engine caching\n    # The 2nd and 3rd iterations are to measure the compilation time with engine caching.\n    # Since the 2nd iteration needs to compile and save the engine, it will be slower than the 1st iteration.\n    # The 3rd iteration should be faster than the 1st iteration because it loads the cached engine.\n    for i in range(iterations):\n        inputs = [torch.rand((100, 3, 224, 224)).to(\"cuda\")]\n        # remove timing cache and reset dynamo just for engine caching measurement\n        remove_timing_cache()\n        torch._dynamo.reset()\n\n        if i == 0:\n            cache_built_engines = False\n            reuse_cached_engines = False\n        else:\n            cache_built_engines = True\n            reuse_cached_engines = True\n\n        start.record()\n        compiled_model = torch.compile(\n            model,\n            backend=\"tensorrt\",\n            options={\n                \"use_python_runtime\": True,\n                \"enabled_precisions\": enabled_precisions,\n                \"min_block_size\": min_block_size,\n                \"immutable_weights\": False,\n                \"cache_built_engines\": cache_built_engines,\n                \"reuse_cached_engines\": reuse_cached_engines,\n            },\n        )\n        with torch.no_grad():\n            compiled_model(*inputs)  # trigger the compilation\n        end.record()\n        torch.cuda.synchronize()\n        times.append(start.elapsed_time(end))\n\n    print(\"----------------torch_compile----------------\")\n    print(\"disable engine caching, used:\", times[0], \"ms\")\n    print(\"enable engine caching to cache engines, used:\", times[1], \"ms\")\n    print(\"enable engine caching to reuse engines, used:\", times[2], \"ms\")\n\n\ntorch_compile()"
      ]
    },
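    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The lookup-then-refit flow described above can be modelled in a few lines of plain Python.\nThis is only a toy sketch under stated assumptions: ``structure_hash``, ``compile_subgraph`` and the ``ops`` list\nare hypothetical names invented for illustration, and the hashing shown here is not the scheme\nTorch-TensorRT actually uses. It only shows why a key that ignores weight values lets a second\ncompilation with different weights hit the cache.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import hashlib\nfrom typing import Dict, List, Sequence, Tuple\n\n# Hypothetical stand-in for a lowered subgraph: (op name, parameter shape) pairs.\nOp = Tuple[str, Tuple[int, ...]]\n\n_toy_cache: Dict[str, str] = {}\n\n\ndef structure_hash(ops: Sequence[Op]) -> str:\n    # The key depends only on op names and shapes, never on weight values.\n    return hashlib.sha256(repr(list(ops)).encode()).hexdigest()\n\n\ndef compile_subgraph(ops: Sequence[Op], weights: List[float]) -> str:\n    # `weights` is deliberately ignored when computing the key.\n    key = structure_hash(ops)\n    if key in _toy_cache:\n        # Cache hit: reuse the stored engine and only swap in the new weights.\n        return \"refit\"\n    # Cache miss: pay the full build cost once and remember the result.\n    _toy_cache[key] = \"engine\"\n    return \"build\"\n\n\nops = [(\"conv2d\", (64, 3, 7, 7)), (\"relu\", ())]\nfirst = compile_subgraph(ops, weights=[0.1, 0.2])\nsecond = compile_subgraph(ops, weights=[0.9, 0.8])  # same structure, new weights\nprint(first, second)"
      ]
    },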
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Engine Caching for AOT Compilation\n\nAs with the JIT workflow, AOT workflows can benefit from engine caching.\nWhen the same architecture or common subgraphs are recompiled, the cache will pull\npreviously built engines and refit the weights.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "def dynamo_compile(iterations=3):\n    times = []\n    start = torch.cuda.Event(enable_timing=True)\n    end = torch.cuda.Event(enable_timing=True)\n\n    example_inputs = (torch.randn((100, 3, 224, 224)).to(\"cuda\"),)\n    # Mark the dim0 of inputs as dynamic\n    batch = torch.export.Dim(\"batch\", min=1, max=200)\n    exp_program = torch.export.export(\n        model, args=example_inputs, dynamic_shapes={\"x\": {0: batch}}\n    )\n\n    # The 1st iteration is to measure the compilation time without engine caching\n    # The 2nd and 3rd iterations are to measure the compilation time with engine caching.\n    # Since the 2nd iteration needs to compile and save the engine, it will be slower than the 1st iteration.\n    # The 3rd iteration should be faster than the 1st iteration because it loads the cached engine.\n    for i in range(iterations):\n        inputs = [torch.rand((100 + i, 3, 224, 224)).to(\"cuda\")]\n        remove_timing_cache()  # remove timing cache just for engine caching measurement\n        if i == 0:\n            cache_built_engines = False\n            reuse_cached_engines = False\n        else:\n            cache_built_engines = True\n            reuse_cached_engines = True\n\n        start.record()\n        trt_gm = torch_trt.dynamo.compile(\n            exp_program,\n            tuple(inputs),\n            use_python_runtime=use_python_runtime,\n            enabled_precisions=enabled_precisions,\n            min_block_size=min_block_size,\n            immutable_weights=False,\n            cache_built_engines=cache_built_engines,\n            reuse_cached_engines=reuse_cached_engines,\n            engine_cache_size=1 << 30,  # 1 GiB\n        )\n        # output = trt_gm(*inputs)\n        end.record()\n        torch.cuda.synchronize()\n        times.append(start.elapsed_time(end))\n\n    print(\"----------------dynamo_compile----------------\")\n    print(\"disable engine caching, used:\", times[0], \"ms\")\n    print(\"enable engine caching to cache engines, used:\", times[1], \"ms\")\n    print(\"enable engine caching to reuse engines, used:\", times[2], \"ms\")\n\n\ndynamo_compile()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Custom Engine Cache\n\nBy default, the engine cache is stored in the system's temporary directory. Both the cache directory and\nsize limit can be customized by passing ``engine_cache_dir`` and ``engine_cache_size``.\nUsers can also define their own engine cache implementation by extending the ``BaseEngineCache`` class.\nThis allows for remote or shared caching if so desired.\n\nThe custom engine cache should implement the following methods:\n  - ``save``: Save the engine blob to the cache.\n  - ``load``: Load the engine blob from the cache.\n\nThe hash provided by the cache system is a weight-agnostic hash of the originating PyTorch subgraph (post lowering).\nThe blob contains a serialized engine, calling spec data, and weight map information in the pickle format.\n\nBelow is an example of a custom engine cache implementation, an in-memory ``RAMEngineCache``.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "class RAMEngineCache(BaseEngineCache):\n    def __init__(\n        self,\n    ) -> None:\n        \"\"\"\n        Constructs a user held engine cache in memory.\n        \"\"\"\n        self.engine_cache: Dict[str, bytes] = {}\n\n    def save(\n        self,\n        hash: str,\n        blob: bytes,\n    ):\n        \"\"\"\n        Insert the engine blob into the cache.\n\n        Args:\n            hash (str): The hash key to associate with the engine blob.\n            blob (bytes): The engine blob to be saved.\n\n        Returns:\n            None\n        \"\"\"\n        self.engine_cache[hash] = blob\n\n    def load(self, hash: str) -> Optional[bytes]:\n        \"\"\"\n        Load the engine blob from the cache.\n\n        Args:\n            hash (str): The hash key of the engine to load.\n\n        Returns:\n            Optional[bytes]: The engine blob if found, None otherwise.\n        \"\"\"\n        if hash in self.engine_cache:\n            return self.engine_cache[hash]\n        else:\n            return None\n\n\ndef torch_compile_my_cache(iterations=3):\n    times = []\n    engine_cache = RAMEngineCache()\n    start = torch.cuda.Event(enable_timing=True)\n    end = torch.cuda.Event(enable_timing=True)\n\n    # The 1st iteration is to measure the compilation time without engine caching\n    # The 2nd and 3rd iterations are to measure the compilation time with engine caching.\n    # Since the 2nd iteration needs to compile and save the engine, it will be slower than the 1st iteration.\n    # The 3rd iteration should be faster than the 1st iteration because it loads the cached engine.\n    for i in range(iterations):\n        inputs = [torch.rand((100, 3, 224, 224)).to(\"cuda\")]\n        # remove timing cache and reset dynamo just for engine caching measurement\n        remove_timing_cache()\n        torch._dynamo.reset()\n\n        if i == 0:\n            cache_built_engines = False\n            reuse_cached_engines = False\n        else:\n            cache_built_engines = True\n            reuse_cached_engines = True\n\n        start.record()\n        compiled_model = torch.compile(\n            model,\n            backend=\"tensorrt\",\n            options={\n                \"use_python_runtime\": True,\n                \"enabled_precisions\": enabled_precisions,\n                \"min_block_size\": min_block_size,\n                \"immutable_weights\": False,\n                \"cache_built_engines\": cache_built_engines,\n                \"reuse_cached_engines\": reuse_cached_engines,\n                \"custom_engine_cache\": engine_cache,\n            },\n        )\n        with torch.no_grad():\n            compiled_model(*inputs)  # trigger the compilation\n        end.record()\n        torch.cuda.synchronize()\n        times.append(start.elapsed_time(end))\n\n    print(\"----------------torch_compile_my_cache----------------\")\n    print(\"disable engine caching, used:\", times[0], \"ms\")\n    print(\"enable engine caching to cache engines, used:\", times[1], \"ms\")\n    print(\"enable engine caching to reuse engines, used:\", times[2], \"ms\")\n\n\ntorch_compile_my_cache()"
      ]
    }
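    ,
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Because a cache only needs ``save`` and ``load``, the shared-caching idea mentioned above can be sketched\nwith a directory on disk. The class below is an illustration under stated assumptions: ``DirEngineCache`` and its\n``.bin`` file naming are hypothetical, and it is written as a plain class so that it runs without\nTorch-TensorRT installed. To plug it into compilation it would presumably subclass ``BaseEngineCache``\nand be passed as ``custom_engine_cache``, as ``RAMEngineCache`` is above.\n\nWriting to a temporary file and then renaming it is one way to keep readers that share the directory\nfrom seeing a half-written blob.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import os\nimport tempfile\nfrom typing import Optional\n\n\nclass DirEngineCache:\n    \"\"\"Hypothetical directory-backed cache with the same save/load shape.\"\"\"\n\n    def __init__(self, cache_dir: str) -> None:\n        self.cache_dir = cache_dir\n        os.makedirs(cache_dir, exist_ok=True)\n\n    def _path(self, hash: str) -> str:\n        # One file per hash; the \".bin\" suffix is an arbitrary choice.\n        return os.path.join(self.cache_dir, f\"{hash}.bin\")\n\n    def save(self, hash: str, blob: bytes) -> None:\n        # Write to a temporary file, then rename it into place, so a reader\n        # never observes a partially written blob.\n        fd, tmp_path = tempfile.mkstemp(dir=self.cache_dir)\n        with os.fdopen(fd, \"wb\") as f:\n            f.write(blob)\n        os.replace(tmp_path, self._path(hash))\n\n    def load(self, hash: str) -> Optional[bytes]:\n        try:\n            with open(self._path(hash), \"rb\") as f:\n                return f.read()\n        except FileNotFoundError:\n            return None\n\n\nwith tempfile.TemporaryDirectory() as tmp_dir:\n    cache = DirEngineCache(tmp_dir)\n    cache.save(\"abc123\", b\"engine-bytes\")\n    hit = cache.load(\"abc123\")\n    miss = cache.load(\"missing\")\n\nprint(hit, miss)"
      ]
    }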
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.11.14"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}